Coverage for cpprb/LaBER.py: 95%


import numpy as np


class LaBER:
    def __init__(self, batch_size: int, m: int = 4, *, eps: float = 1e-6):
        """
        Initialize LaBER (sub-)class

        Parameters
        ----------
        batch_size : int
            Batch size for neural network training.
        m : int, optional
            Multiplication factor. A large batch of ``m * batch_size``
            transitions must be passed at call time. Default value is ``4``.
        eps : float, optional
            Small positive value added to priorities to avoid zero
            probability. Default value is ``1e-6``.

        Raises
        ------
        ValueError
            When ``batch_size <= 0``, ``m <= 0``, or ``eps <= 0``.
        """

        self.rng = np.random.default_rng()

        self.batch_size = int(batch_size)
        if self.batch_size <= 0:
            raise ValueError("``batch_size`` must be a positive integer.")

        if m <= 0:
            raise ValueError("``m`` must be a positive integer.")

        self.idx = np.arange(int(self.batch_size * m))

        self.eps = float(eps)
        if self.eps <= 0:
            raise ValueError("``eps`` must be positive.")

    def __call__(self, *, priorities, **kwargs):
        """
        Sub-sample from the large batch

        Parameters
        ----------
        priorities : array-like of float
            Surrogate priorities.
        **kwargs : key-value
            Large batch sampled from ``ReplayBuffer``. These values are also
            included in the sub-sampled batch.

        Returns
        -------
        dict
            Sub-sampled batch, which includes ``"weights"``, ``"indexes"``,
            and the passed keys.

        Raises
        ------
        ValueError
            If the size of ``priorities`` is not ``batch_size * m``.
        """

        p = np.asarray(priorities) + self.eps
        if p.shape != self.idx.shape:
            raise ValueError("``priorities`` size must be ``batch_size * m``")

        p = p / p.sum()

        _idx = self.rng.choice(self.idx, self.batch_size, p=p)

        # ``**kwargs`` always binds to a dict (possibly empty), so the
        # sub-sampling can be applied unconditionally.
        kwargs = {k: v[_idx] for k, v in kwargs.items()}

        kwargs["weights"] = self._normalize_weight(p, _idx)
        kwargs["indexes"] = _idx

        return kwargs

    def _normalize_weight(self, p, _idx):
        raise NotImplementedError
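
# Usage sketch (illustrative only): how a LaBER functor is meant to be
# combined with a replay buffer. ``buffer`` and ``compute_abs_td_error`` are
# hypothetical stand-ins, not part of this module; any buffer whose
# ``sample(n)`` returns a dict of equally sized arrays fits this pattern.
#
#     laber = LaBERmean(batch_size=32, m=4)
#     large_batch = buffer.sample(32 * 4)                # 1. large batch
#     priorities = compute_abs_td_error(large_batch)     # 2. surrogate priorities
#     mini_batch = laber(priorities=priorities, **large_batch)
#     # 3. ``mini_batch`` holds "weights", "indexes", and the buffer's keys,
#     #    each reduced to ``batch_size`` rows.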



class LaBERmean(LaBER):
    """
    Helper class for Large Batch Experience Replay (LaBER)

    This helper class is a functor designed to be used together with ``ReplayBuffer``.
    It takes surrogate priorities for a large batch, then returns sub-sampled
    indexes and weights.

    See Also
    --------
    LaBERmax, LaBERlazy : Other variants

    Notes
    -----
    In LaBER [1]_, an ``m``-times larger batch (large batch) is first sampled
    from the Replay Buffer. The final mini-batch is then sub-sampled from the
    large batch based on newly calculated surrogate priorities.
    This class implements the LaBER-mean variant, where weights are normalized
    by the average over the large batch.

    References
    ----------
    .. [1] T. Lahire et al., "Large Batch Experience Replay", CoRR (2021)
       https://dblp.org/db/journals/corr/corr2110.html#journals/corr/abs-2110-01528
       https://arxiv.org/abs/2110.01528
    """
    def _normalize_weight(self, p, _idx):
        return p.mean() / p[_idx]
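
# A note on the LaBER-mean weights above (this follows directly from the code,
# it is not an extra claim from the paper): ``p`` is normalized to sum to one
# before sub-sampling, so ``p.mean() == 1 / len(p)`` and the returned weight
# equals the standard importance-sampling correction
# ``1 / (len(p) * p[_idx])`` for drawing index ``i`` with probability ``p[i]``
# out of ``len(p)`` candidates.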



class LaBERlazy(LaBER):
    """
    Helper class for Large Batch Experience Replay (LaBER)

    This helper class is a functor designed to be used together with ``ReplayBuffer``.
    It takes surrogate priorities for a large batch, then returns sub-sampled
    indexes and weights.

    Warnings
    --------
    According to the original paper [1]_, ``LaBERmean`` is preferable.

    See Also
    --------
    LaBERmean, LaBERmax : Other variants

    Notes
    -----
    In LaBER [1]_, an ``m``-times larger batch (large batch) is first sampled
    from the Replay Buffer. The final mini-batch is then sub-sampled from the
    large batch based on newly calculated surrogate priorities.
    This class implements the LaBER-lazy variant, where weights are not
    normalized at all.

    References
    ----------
    .. [1] T. Lahire et al., "Large Batch Experience Replay", CoRR (2021)
       https://dblp.org/db/journals/corr/corr2110.html#journals/corr/abs-2110-01528
       https://arxiv.org/abs/2110.01528
    """
    def _normalize_weight(self, p, _idx):
        return 1.0 / p[_idx]
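
# Relationship between the variants (a direct consequence of the formulas,
# not a statement from the paper): LaBER-lazy weights are exactly ``len(p)``
# times the LaBER-mean weights, since ``p.mean() == 1 / len(p)``; the two
# differ only in the overall scale applied to the loss.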



class LaBERmax(LaBER):
    """
    Helper class for Large Batch Experience Replay (LaBER)

    This helper class is a functor designed to be used together with ``ReplayBuffer``.
    It takes surrogate priorities for a large batch, then returns sub-sampled
    indexes and weights.

    Warnings
    --------
    According to the original paper [1]_, ``LaBERmean`` is preferable.

    See Also
    --------
    LaBERmean, LaBERlazy : Other variants

    Notes
    -----
    In LaBER [1]_, an ``m``-times larger batch (large batch) is first sampled
    from the Replay Buffer. The final mini-batch is then sub-sampled from the
    large batch based on newly calculated surrogate priorities.
    This class implements the LaBER-max variant, where weights are normalized
    by the maximum weight of the selected mini-batch.

    References
    ----------
    .. [1] T. Lahire et al., "Large Batch Experience Replay", CoRR (2021)
       https://dblp.org/db/journals/corr/corr2110.html#journals/corr/abs-2110-01528
       https://arxiv.org/abs/2110.01528
    """
    def _normalize_weight(self, p, _idx):
        # Rescale the raw ``1 / p`` weights so the largest weight in the
        # selected mini-batch is exactly 1.
        p_idx = 1.0 / p[_idx]
        return p_idx / p_idx.max()
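

# Minimal self-contained sketch (illustrative only; a real setup would obtain
# the large batch from a ``ReplayBuffer`` and the priorities from a TD-error
# computation). The keys "obs" and "rew" are arbitrary placeholders.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch_size, m = 4, 4

    # Stand-in for ``ReplayBuffer.sample(batch_size * m)``.
    large_batch = {
        "obs": rng.normal(size=(batch_size * m, 3)),
        "rew": rng.normal(size=batch_size * m),
    }
    # Stand-in for surrogate priorities (e.g. absolute TD errors).
    priorities = rng.uniform(size=batch_size * m)

    for cls in (LaBERmean, LaBERlazy, LaBERmax):
        laber = cls(batch_size, m)
        sample = laber(priorities=priorities, **large_batch)
        assert sample["obs"].shape == (batch_size, 3)
        assert sample["weights"].shape == (batch_size,)
        print(cls.__name__, sample["indexes"], sample["weights"])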